{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import cv2\n",
    "import random\n",
    "import time\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(28140, 784)\n",
      "(13860, 784)\n"
     ]
    }
   ],
   "source": [
    "raw_data = pd.read_csv('../data/train.csv',header=0)\n",
    "data = raw_data.values\n",
    "imgs = data[:, 1:]\n",
    "labels = data[:, 0]\n",
    "# 选取 2/3 数据作为训练集， 1/3 数据作为测试集\n",
    "train_features, test_features, train_labels, test_labels = train_test_split(imgs, labels, test_size=0.33, random_state=23323)\n",
    "\n",
    "print(train_features.shape)\n",
    "print(test_features.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 二值化\n",
    "def binaryzation(img):\n",
    "    cv_img = img.astype(np.uint8)\n",
    "    cv2.threshold(cv_img, 50, 1, cv2.THRESH_BINARY_INV, cv_img)\n",
    "    return cv_img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "    cv2.threshold(cv_img, 50, 1, cv2.THRESH_BINARY_INV, cv_img)\n",
    "这句代码中，cv_img是输入的784个pixel数字（0~255），50表示阈值，1表示最大值，cv2.THRESH_BINARY_INV表示二值化的类型。这句代码表示pixel数字大于50的部分，为1，小于50的部分，为0。\n",
    "\n",
    "看一下经过二值化处理后是什么效果："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  63,\n",
       "       255, 253, 253, 244, 120,  22,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  12,\n",
       "       100, 209, 253, 252, 252, 252, 252, 187,   6,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "       144, 217, 252, 161, 253, 183, 153, 106, 218, 252,  70,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "        87, 180, 242, 243, 202,  68,  10,   3,   0,   0,  60, 194,  31,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   5, 184, 252, 226,  93,  23,   0,   0,   0,   0,   0,  32,\n",
       "       142, 179,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0, 195, 252, 183,  29,   0,   0,   0,   0,   0,   0,\n",
       "         0, 141, 252,  45,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,  48, 247, 173,  38,   0,   0,   0,   0,   0,\n",
       "         0,   0,  26, 245, 252,  74,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0, 100, 229,  72,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0, 132, 252, 252, 131,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,  26, 153,  27,   0,   0,\n",
       "         0,   0,   0,   0,   0,  34, 132, 252, 252,  98,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 242, 159,\n",
       "       111,  58,  68,  77,   0,  15,  34,  14, 180, 252, 252,  21,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "       181, 253, 253, 253, 253, 114,   0,   0,   0, 100, 253, 253, 141,\n",
       "        10,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,  50, 229, 252, 249, 120,  20,   0,   0,   0, 176, 252,\n",
       "       252,  55,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,  29,  44,  42,   0,   0,   0,   0,   0,\n",
       "       209, 252, 206,  10,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         3, 128, 251, 252,  92,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,  58, 252, 252, 238,  31,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,  39, 230, 252, 252, 143,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0, 116, 253, 252, 252,  20,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,  14, 226, 253, 252, 172,   4,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0, 159, 252, 253, 232,  30,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,  20, 216, 252, 253,\n",
       "       186,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
       "         0,   0,   0,   0])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n",
       "       0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "       0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n",
       "       0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,\n",
       "       0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "       0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "       0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1], dtype=uint8)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "binaryzation(trainset[0])     # 图片二值化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "trainset, train_labels = train_features, train_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class_num = 10\n",
    "feature_len = 784"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10,)\n",
      "(10, 784, 2)\n"
     ]
    }
   ],
   "source": [
    "# 存放先验概率\n",
    "prior_probability = np.zeros(class_num)      \n",
    "print(prior_probability.shape)\n",
    "# 存放条件概率\n",
    "conditional_probability = np.zeros((class_num, feature_len, 2))    \n",
    "print(conditional_probability.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "具体文章参考这一篇[机器学习通俗入门-朴素贝叶斯分类器](http://blog.csdn.net/TaiJi1985/article/details/73657994)\n",
    "\n",
    "$x^{(i)}$ 为一个28维的向量表示第i个样本， $y^{(i)}$ 为标注的类别。我们求解的目标是：\n",
    "\n",
    "$$f = \\underset{j}{arg maxP} (y^{(i)} = j \\mid x^{(i)}) $$\n",
    "\n",
    "简单说就是计算$p (y^{(i)} = 0 \\mid x^{(i)}) $，$p (y^{(i)} = 1 \\mid x^{(i)}) $ ... $p (y^{(i)} = 9 \\mid x^{(i)}) $，从中找出一个最大的，如果从属于第j个类的概率最大，那么就认为这张图片从属于j这个类。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 计算先验概率及条件概率\n",
    "for i in range(len(train_labels)):\n",
    "    img = binaryzation(trainset[i])     # 图片二值化\n",
    "    label = train_labels[i]\n",
    "\n",
    "    prior_probability[label] += 1      # 每个label的图片各有多少个\n",
    "\n",
    "    for j in range(feature_len):\n",
    "        conditional_probability[label][j][img[j]] += 1  \n",
    "        # img[j]表示在像素点j上的值。如果是0，就会给第一个位置+1，如果是1，会给第二个位置+1\n",
    "        # 比如下面的conditional_probability[0][0]，结果是[0, 2711]。\n",
    "        # 说明在img中，标签为0的样本中，像素点为0的对应位置，在img中分别为0或1的数量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 2711.,  3197.,  2828.,  2897.,  2751.,  2565.,  2769.,  2964.,\n",
       "        2654.,  2804.])"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prior_probability"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "把上面的循环拆解分析一下，这里取第一个训练样本：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img1 = binaryzation(trainset[0])     # 图片二值化, 784个像素（feature）中，要么是0，要么是1\n",
    "label1 = train_labels[0]\n",
    "label1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img1[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([    0.,  2711.])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(conditional_probability[0][0]) # "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "说明在img中，标签为0的样本中，像素点为500的对应位置，在img中为0的样本数量是199，为1的样本数量是2512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[  199.  2512.]\n"
     ]
    }
   ],
   "source": [
    "print(conditional_probability[0][500]) # "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面之所以将概率归到[1.10001]，是因为上面所有关于概率的部分都是直接用样本数量，而不是实际的概率来记录的。这么做应该是为了在工程上解决内存，但是这种工程上的优化，对于理解书中的公式造成了影响。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "而且下面计算概率的时候有点问题：\n",
    "            probalility_0 = (float(pix_0)/float(pix_0+pix_1))*1000000 + 1\n",
    "分母部分是，属于i类(0~9)的图像中，像素j的数量……对啊，这个像素j的数量其实就是pix_0和pix_1的和，即属于i类的图像的数量。看来这里没问题，是我想多了。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 将概率归到[1.10001]\n",
    "for i in range(class_num):\n",
    "    for j in range(feature_len):\n",
    "\n",
    "        # 经过二值化后图像只有0，1两种取值\n",
    "        pix_0 = conditional_probability[i][j][0]  # 属于i类(0~9)的图像中，像素j(0~783)为0的数量\n",
    "        pix_1 = conditional_probability[i][j][1]  # 属于i类(0~9)的图像中，像素j(0~783)为1的数量\n",
    "\n",
    "        # 计算0，1像素点对应的条件概率\n",
    "        probalility_0 = (float(pix_0)/float(pix_0+pix_1))*1000000 + 1\n",
    "        probalility_1 = (float(pix_1)/float(pix_0+pix_1))*1000000 + 1\n",
    "\n",
    "        conditional_probability[i][j][0] = probalility_0\n",
    "        conditional_probability[i][j][1] = probalility_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  1.00000000e+00,   1.00000100e+06])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conditional_probability[0][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "得到了prior_probability和conditional_probability，这就算是训练结束了。\n",
    "\n",
    "# test (predict)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100, 784)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 为了加快预测速度，这里直接取100个测试样本\n",
    "\n",
    "testset = test_features[:100]\n",
    "testset.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "$$p (y^{(i)} = j \\mid x_{k}^{(i)})  = \\frac{p (x_{k}^{(i)} \\mid y^{(i)} = j) \\cdot p(y^{(i)} = j)}{p(x_{k}^{(i)})}$$\n",
    "\n",
    "$p (y^{(i)} = j \\mid x_{k}^{(i)}) $中，$y^{(i)} = j$表示从属于哪一类，$x_{k}^{(i)}$表示哪一个像素点。\n",
    "\n",
    "下面calculate_probability函数就是在计算分子部分。\n",
    "\n",
    "`probability *= int(conditional_probability[label][i][img[i]])`\n",
    "\n",
    "这行代码中：\n",
    "- probability表示先验概率 $p(y^{(i)} = j)$\n",
    "- `conditional_probability[label][i][img[i]]`表示 $p (x_{k}^{(i)} \\mid y^{(i)} = j) $\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 计算不同标签下，testdata的概率\n",
    "def calculate_probability(img, label):\n",
    "    probability = int(prior_probability[label])\n",
    "\n",
    "    for i in range(len(img)):\n",
    "        probability *= int(conditional_probability[label][i][img[i]])\n",
    "\n",
    "    return probability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "predict = []\n",
    "\n",
    "for img in testset:\n",
    "\n",
    "    # 图像二值化\n",
    "    img = binaryzation(img)\n",
    "\n",
    "    max_label = 0\n",
    "    max_probability = calculate_probability(img, 0)\n",
    "\n",
    "    for j in range(1, 10):\n",
    "        probability = calculate_probability(img, j)\n",
    "\n",
    "        if max_probability < probability:\n",
    "            max_label = j\n",
    "            max_probability = probability\n",
    "\n",
    "    predict.append(max_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test_predict = np.array(predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.76000000000000001"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "score = accuracy_score(test_labels[:100], test_predict)\n",
    "score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 重构朴素贝叶斯算法\n",
    "\n",
    "![](https://pic1.zhimg.com/v2-e17426fd0627560f1fc82118dd1d5d14_r.jpg)\n",
    "\n",
    "朴素贝叶斯认为所有特征都是独立的，然后得出一个样本出现的概率使其所有特征出现概率的联乘。\n",
    "\n",
    "首先求每一个标签的先验概率："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10,)\n",
      "(10, 784, 2)\n"
     ]
    }
   ],
   "source": [
    "class_num = 10\n",
    "feature_len = 784\n",
    "\n",
    "# 存放每个label的数量\n",
    "class_number = np.zeros(class_num)      \n",
    "\n",
    "# 存放先验概率\n",
    "prior_probability = np.zeros(class_num)      \n",
    "print(prior_probability.shape)\n",
    "# 存放条件概率\n",
    "conditional_probability = np.zeros((class_num, feature_len, 2))    \n",
    "print(conditional_probability.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 计算先验概率\n",
    "for i in range(len(train_labels)):\n",
    "    img = binaryzation(trainset[i])     # 图片二值化\n",
    "    label = train_labels[i]\n",
    "\n",
    "    class_number[label] += 1      # 每个label的图片各有多少个\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.09633973,  0.11361052,  0.10049751,  0.10294954,  0.09776119,\n",
       "        0.09115139,  0.09840085,  0.10533049,  0.09431414,  0.09964463])"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class_number/len(train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 2711.,  3197.,  2828.,  2897.,  2751.,  2565.,  2769.,  2964.,\n",
       "        2654.,  2804.])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class_number"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "prior_probability = class_number / len(train_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "计算条件概率: \n",
    "\n",
    "$$p (X^{(i)} = a_{jl} \\mid Y = c_k)$$\n",
    "\n",
    "在标签为$c_k$的前提下，样本x的第$j$个特征（像素点）的第$l$个值（经过二值化处理，这里的$l$只有0或1两种可能）。conditional_probability的维度是`(10, 784, 2)`，最后的那个2，指的就是每个特征可以取的值。如果不做二值化处理，那么每个像素点应该有256种取值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 条件概率\n",
    "conditional_probability = np.zeros((class_num, feature_len, 2))    \n",
    "\n",
    "for i in range(len(train_labels)):\n",
    "    img = binaryzation(trainset[i])     # 图片二值化\n",
    "    label = train_labels[i]\n",
    "    for j in range(feature_len):\n",
    "        conditional_probability[label][j][img[j]] += 1   # 这里只得到a_jl的数量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  199.,  2512.])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conditional_probability[0][500] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2711.0"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class_number[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.07340465,  0.92659535])"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conditional_probability[0][500] / class_number[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "conditional_probability_fraction = np.zeros((class_num, feature_len, 2))    \n",
    "\n",
    "for i in range(len(train_labels)):\n",
    "    label = train_labels[i]\n",
    "    for j in range(feature_len):\n",
    "        conditional_probability_fraction[label][j] = conditional_probability[label][j] / class_number[label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.07340465,  0.92659535])"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conditional_probability_fraction[0][500]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "发现上面如果分开两循环写的话冗长，这里还是应该写在一起："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 计算先验概率及条件概率\n",
    "for i in range(len(train_labels)):\n",
    "    img = binaryzation(trainset[i])     # 图片二值化\n",
    "    label = train_labels[i]\n",
    "\n",
    "    class_number[label] += 1      # 每个label的图片各有多少个\n",
    "    prior_probability = class_number / len(train_labels)\n",
    "\n",
    "    for j in range(feature_len):\n",
    "        conditional_probability[label][j][img[j]] += 1  \n",
    "        # 在所有训练样本中，标签=0的样本中，像素点=0的对应位置上，一共有多少个样本是0，一共有多少个样本是1\n",
    "        \n",
    "# 推荐概率        \n",
    "for i in range(class_num):\n",
    "    for j in range(feature_len):\n",
    "        conditional_probability[label][j] = conditional_probability[label][j] / class_number[label]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上面就算完成了第一步，计算完了先验概率和条件概率。接下来第二步对测试集进行预测:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100, 784)"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "testset.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 写一个函数来计算每一个标签下，对应的概率\n",
    "def calculate_probability(img, label):\n",
    "    probability = prior_probability[label] # 先验概率\n",
    "\n",
    "    # 对每一个像素点进行迭代，计算在laebl固定的情况下，每一个像素点的概率，然后联乘\n",
    "    for i in range(len(img)):\n",
    "        probability *= conditional_probability[label][i][img[i]] \n",
    "        # [i]表示一个测试样本中，第i个像素点\n",
    "        # img[i]表示一个测试样本中，第i个像素点是0还是1\n",
    "\n",
    "    return probability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/xu/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel/__main__.py:7: RuntimeWarning: overflow encountered in double_scalars\n",
      "/Users/xu/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel/__main__.py:7: RuntimeWarning: invalid value encountered in double_scalars\n"
     ]
    }
   ],
   "source": [
    "predict = []\n",
    "\n",
    "for img in testset:\n",
    "    img = binaryzation(img)\n",
    "    \n",
    "    max_label = 0\n",
    "    max_probability = calculate_probability(img, 0)\n",
    "    \n",
    "    for j in range(1, 10):\n",
    "        probability = calculate_probability(img, j)\n",
    "        \n",
    "        if max_probability < probability:\n",
    "            max_label = j\n",
    "            max_probability = probability\n",
    "\n",
    "    predict.append(max_label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "看来确实是这样，为了防止溢出，所以源代码里才一直用数量代替。\n",
    "\n",
    "看来这个算法不需要我改了，源代码其实已经考虑了溢出的问题。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [py35]",
   "language": "python",
   "name": "Python [py35]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
